import csv
import io
import json

import cvss.parser
from cvss.cvss3 import CVSS3

from dojo.models import Finding
from dojo.tools.sysdig_common.sysdig_data import SysdigData
from dojo.validators import clean_tags


class SysdigCLIParser:

    """Sysdig CLI Report Importer - CSV and JSON reports produced by the Sysdig CLI scanner"""

    def get_scan_types(self):
        return ["Sysdig CLI Report"]

    def get_label_for_scan_types(self, scan_type):
        return "Sysdig CLI Report Scan"

    def get_description_for_scan_types(self, scan_type):
        return "Import of Sysdig Report generated by the Sysdig CLI scanner"

    def get_findings(self, filename, test):
        if filename is None:
            return ()
        if filename.name.lower().endswith(".csv"):
            arr_data = self.load_csv(filename)
            return self.parse_csv(arr_data=arr_data, test=test)
        if filename.name.lower().endswith(".json"):
            scan_data = filename.read()
            try:
                data = json.loads(str(scan_data, "utf-8"))
            except Exception:
                data = json.loads(scan_data)
                if "result" in data:
                    return self.parse_json(data=data, test=test)

                if "data" in data:
                    msg = "JSON file is not in the expected format, it looks like a Sysdig Vulnerability Report."
                else:
                    msg = "JSON file is not in the expected format, expected data result element"

                raise ValueError(msg)
        return ()

    def parse_json(self, data, test):
        findings = []
        packages = data.get("result", {}).get("packages", [])

        for package in packages:
            # print(package)
            packageName = package.get("name", "")
            packageType = package.get("type", "")
            packagePath = package.get("path", "")
            packageVersion = package.get("version", "")
            packageSuggestedFix = package.get("suggestedFix", "")
            layerDigest = package.get("layerDigest", "")

            vulns = package.get("vulns", [])
            # print("vulns: %s" % vulns)
            for item in vulns:
                # print("item: %s" % item)
                vulnName = item.get("name", "")
                vulnSeverity = SysdigData._map_severity(item.get("severity", {}).get("value", ""))
                vulnCvssScore = item.get("cvssScore", {}).get("value", {}).get("score", "")
                vulnCvssVersion = item.get("cvssScore", {}).get("value", {}).get("version", "")
                vulnCvssVector = item.get("cvssScore", {}).get("value", {}).get("vector", "")
                vulnDisclosureDate = item.get("disclosureDate", "")
                vulnSolutionDate = item.get("solutionDate", "")
                vulnPublishedByVendorDate = item.get("publishDateByVendor", {}).get("nvd", "")
                vulnExploitable = item.get("exploitable", "")
                vulnFixVersion = item.get("fixedInVersion", "")

                description = ""
                description += "vulnCvssVersion: " + vulnCvssVersion + "\n"
                description += "vulnCvssScore: " + str(vulnCvssScore) + "\n"
                description += "vulnCvssVector: " + vulnCvssVector + "\n"
                description += "vulnDisclosureDate: " + vulnDisclosureDate + "\n"
                description += "vulnPublishedByVendorDate: " + vulnPublishedByVendorDate + "\n"
                description += "vulnSolutionDate: " + vulnSolutionDate + "\n"
                description += "vulnExploitable: " + str(vulnExploitable) + "\n"
                description += "packageName: " + packageName + "\n"
                description += "packageType: " + packageType + "\n"
                description += "packagePath: " + packagePath + "\n"
                description += "packageVersion: " + packageVersion + "\n"
                description += "packageSuggestedFix: " + packageSuggestedFix + "\n"
                description += "layerDigest: " + layerDigest + "\n"

                mitigation = ""
                mitigation += "vulnFixVersion: " + vulnFixVersion + "\n"
                mitigation += "suggestedFix: " + packageSuggestedFix + "\n"

                finding = Finding(
                    title=vulnName + " - " + packageName + " - " + vulnFixVersion,
                    test=test,
                    description=description,
                    severity=vulnSeverity,
                    mitigation=mitigation,
                    static_finding=True,
                    component_name=packageName,
                    component_version=packageVersion,
                )

                try:
                    if float(vulnCvssVersion) >= 3 and float(vulnCvssVersion) < 4:
                        finding.cvssv3_score = vulnCvssScore
                        vectors = cvss.parser.parse_cvss_from_text(vulnCvssVector)
                        if len(vectors) > 0 and isinstance(vectors[0], CVSS3):
                            finding.cvssv3 = vectors[0].clean_vector()
                except ValueError:
                    continue

                if vulnName:
                    finding.unsaved_vulnerability_ids = []
                    finding.unsaved_vulnerability_ids.append(vulnName)
                findings.append(finding)
        return findings

    def parse_csv(self, arr_data, test):
        if len(arr_data) == 0:
            return ()
        sysdig_report_findings = []
        for row in arr_data:
            finding = Finding(test=test)
            # Generate finding
            finding.title = f"{row.vulnerability_id} - {row.package_name}"
            finding.vuln_id_from_tool = row.vulnerability_id
            finding.unsaved_vulnerability_ids = []
            finding.unsaved_vulnerability_ids.append(row.vulnerability_id)
            finding.severity = SysdigData._map_severity(row.severity)
            # Set Component Version
            finding.component_name = row.package_name
            finding.component_version = row.package_version
            # Set some finding tags
            tags = []
            if row.vulnerability_id:
                tags.append(clean_tags("VulnId:" + row.vulnerability_id))
            finding.tags = tags
            finding.dynamic_finding = False
            finding.static_finding = True
            finding.description += "\n\n###Vulnerability Details"
            finding.description += f"\n - **Vulnerability ID:** {row.vulnerability_id}"
            finding.description += f"\n - **Vulnerability Link:** {row.vuln_link}"
            finding.description += f"\n - **Severity:** {row.severity}"
            finding.description += f"\n - **Publish Date:** {row.vuln_publish_date}"
            finding.description += f"\n - **CVSS Version:** {row.cvss_version}"
            finding.description += f"\n - **CVSS Vector:** {row.cvss_vector}"
            if row.public_exploit:
                finding.description += f"\n - **Public Exploit:** {row.public_exploit}"
            finding.description += "\n\n###Package Details"
            if row.package_type == "os":
                finding.description += f"\n - **Package Type: {row.package_type} \\* Consider upgrading your Base OS \\***"
            else:
                finding.description += f"\n - **Package Type:** {row.package_type}"
            finding.description += f"\n - **Package Name:** {row.package_name}"
            finding.description += f"\n - **Package Version:** {row.package_version}"
            if row.package_path:
                finding.description += f"\n - **Package Path:** {row.package_path}"
                finding.file_path = row.package_path
            try:
                if float(row.cvss_version) >= 3 and float(row.cvss_version) < 4:
                    finding.cvssv3_score = float(row.cvss_score)
                    vectors = cvss.parser.parse_cvss_from_text(row.cvss_vector)
                    if len(vectors) > 0 and isinstance(vectors[0], CVSS3):
                        finding.cvssv3 = vectors[0].clean_vector()
            except ValueError:
                continue
            finding.risk_accepted = row.risk_accepted
            # Set reference
            if row.vuln_link:
                finding.references = row.vuln_link
                finding.url = row.vuln_link
            finding.epss_score = row.epss_score
            # finally, Add finding to list
            sysdig_report_findings.append(finding)
        return sysdig_report_findings

    def load_csv(self, filename) -> SysdigData:

        if filename is None:
            return ()

        content = filename.read()
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        reader = csv.DictReader(io.StringIO(content), delimiter=",", quotechar='"')

        # normalise on lower case for consistency
        reader.fieldnames = [name.lower() for name in reader.fieldnames]

        csvarray = []

        for row in reader:
            # Compare headers to values.
            if len(row) != len(reader.fieldnames):
                msg = f"Number of fields in row ({len(row)}) does not match number of headers ({len(reader.fieldnames)})"
                raise ValueError(msg)

            # Check for a CVE value to being with
            if not row[reader.fieldnames[0]].startswith("CVE"):
                msg = f"Expected 'CVE' at the start but got: {row[reader.fieldnames[0]]}"
                raise ValueError(msg)

            csvarray.append(row)

        if "vulnerability id" in reader.fieldnames:
            msg = "Unknown CSV format: Vulnerability ID column found, looks like a SysDig Vulnerability Report"
            raise ValueError(msg)

        if "cve id" not in reader.fieldnames:
            msg = "Unknown CSV format: expected CVE ID column"
            raise ValueError(msg)

        arr_csv_data = []
        for row in csvarray:

            csv_data_record = SysdigData()
            msg = ""
            # Sydig CLI format
            csv_data_record.vulnerability_id = row.get("cve id", "")
            csv_data_record.severity = SysdigData._map_severity(row.get("cve severity").upper())
            csv_data_record.cvss_score = row.get("cvss score", "")
            csv_data_record.cvss_version = row.get("cvss score version", "")
            csv_data_record.package_name = row.get("package name", "")
            csv_data_record.package_version = row.get("package version", "")
            csv_data_record.package_type = row.get("package type", "")
            csv_data_record.package_path = row.get("package path", "")
            csv_data_record.vuln_fix_version = row.get("fix version", "")
            csv_data_record.vuln_link = row.get("cve url", "")
            csv_data_record.vuln_publish_date = row.get("vuln disclosure date", "")
            csv_data_record.vuln_fix_date = row.get("vuln fix date", "")
            csv_data_record.risk_accepted = row.get("risk accepted", "") == "TRUE"

            # new fields:
            csv_data_record.epss_score = row.get("epss score", "")

            # not present:
            # csv_data_record.public_exploit = row.get("public exploit", "")
            # csv_data_record.cvss_vector = row.get("cvss vector", "")
            # csv_data_record.image = row.get("image", "")
            # csv_data_record.os_name = row.get("os name", "")
            # csv_data_record.k8s_cluster_name = row.get("k8s cluster name", "")
            # csv_data_record.k8s_namespace_name = row.get("k8s namespace name", "")
            # csv_data_record.k8s_workload_type = row.get("k8s workload type", "")
            # csv_data_record.k8s_workload_name = row.get("k8s workload name", "")
            # csv_data_record.k8s_container_name = row.get("k8s container name", "")
            # csv_data_record.image_id = row.get("image id", "")
            # csv_data_record.k8s_pod_count = row.get("k8s pod count", "")
            # csv_data_record.package_suggested_fix = row.get("package suggested fix", "")
            # csv_data_record.in_use = row.get("in use", "") == "TRUE"
            # csv_data_record.registry_name = row.get("registry name", "")
            # csv_data_record.registry_image_repository = row.get("registry image repository", "")
            # csv_data_record.cloud_provider_name = row.get("cloud provider name", "")
            # csv_data_record.cloud_provider_account_id = row.get("cloud provider account ID", "")
            # csv_data_record.cloud_provider_region = row.get("cloud provider region", "")
            # csv_data_record.registry_vendor = row.get("registry vendor", "")

            arr_csv_data.append(csv_data_record)

        return arr_csv_data
